
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
from .config import settings

# 定义集合名称和相关参数
COLLECTION_NAME = "face_features_collection"
FACE_VECTOR_DIM = 512  # InsightFace 'buffalo_l' 模型的特征维度

class MilvusHelper:
    def __init__(self):
        try:
            # 连接 Milvus 服务
            connections.connect("default", host=settings.MILVUS_HOST, port=settings.MILVUS_PORT)
            print("成功连接到 Milvus 服务。")
        except Exception as e:
            print(f"连接 Milvus 服务失败: {e}")
            raise

    def has_collection(self):
        """检查集合是否存在"""
        return utility.has_collection(COLLECTION_NAME)

    def create_collection(self):
        """创建一个新的集合来存储人脸特征"""
        if self.has_collection():
            print(f"集合 '{COLLECTION_NAME}' 已存在。")
            return

        # 定义字段
        # 主键字段，Milvus 会自动生成ID
        pk_field = FieldSchema(name="feature_id", dtype=DataType.INT64, is_primary=True, auto_id=True)
        # 对应的用户ID字段
        user_id_field = FieldSchema(name="user_id", dtype=DataType.INT64)
        # 人脸特征向量字段
        embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=FACE_VECTOR_DIM)

        # 创建集合 Schema
        schema = CollectionSchema(
            fields=[pk_field, user_id_field, embedding_field],
            description="人脸识别特征集合",
            enable_dynamic_field=False
        )

        # 创建集合
        self.collection = Collection(name=COLLECTION_NAME, schema=schema)
        print(f"集合 '{COLLECTION_NAME}' 创建成功。")

        # 为向量字段创建索引以加速搜索
        index_params = {
            "metric_type": "IP",  # IP (Inner Product) 等价于归一化向量的余弦相似度
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024} # nlist 的值需要根据数据量调整
        }
        self.collection.create_index(field_name="embedding", index_params=index_params)
        print("向量索引创建成功。")
        return self.collection

    def get_collection(self):
        """获取集合对象并加载到内存中以便搜索"""
        if not self.has_collection():
            self.create_collection()
        collection = Collection(COLLECTION_NAME)
        collection.load()
        return collection

# 在模块加载时创建一个全局实例
milvus_client = MilvusHelper()
collection = milvus_client.get_collection()
